• 0.1 Import the libraries.
  • 0.2 Load the data.
  • 0.3 Preprocess.
  • 0.4 Build the Neural Net.

Binary Classification using a deep neural network with ‘L’ hidden layers.

Notation:

  • Superscript [l] denotes a quantity associated with the lth layer.
    • Example: a[L] is the Lth layer activation; W[L] and b[L] are the Lth layer parameters.
  • Superscript (i) denotes a quantity associated with the ith example.
    • Example: x(i) is the ith training example.
  • Subscript i denotes the ith entry of a vector.
    • Example: a_i[l] denotes the ith entry of the lth layer's activations.

0.1 Import the libraries.

0.2 Load the data.
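The data resembles two iris species reduced to two sepal features plus a 0/1 label. The exact source and split aren't shown; a minimal, hypothetical reconstruction from base R's built-in iris would be:

```r
# Hypothetical reconstruction: keep two species, binarize the label,
# and retain only the two sepal features.
df <- subset(iris, Species %in% c("setosa", "versicolor"),
             select = c(Sepal.Length, Sepal.Width, Species))
df$Species <- as.integer(df$Species == "versicolor")

nrow(df)           # 100 rows before any train/test split
table(df$Species)  # 50 of each class
```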

Train dataset.

Sepal.Length  Sepal.Width  Species
<dbl>         <dbl>        <int>
5.7           2.6          1
4.5           2.3          0
5.0           2.3          1
5.1           3.8          0
5.6           2.7          1
5.6           2.5          1
5.0           3.6          0
5.4           3.0          1
5.0           3.0          0
5.4           3.4          0

Test dataset.

Sepal.Length  Sepal.Width  Species
<dbl>         <dbl>        <int>
5.1           3.3          0
5.4           3.4          0
5.3           3.7          0
5.0           3.4          0
6.1           3.0          1
4.9           2.4          1
5.5           2.6          1
4.3           3.0          0
5.1           3.8          0
5.4           3.9          0

Plot the data.

0.3 Preprocess.

Standardize the feature values (center to zero mean and scale to unit variance).

## Shape of X (row, column): 
##  80 2 
## Shape of y (row, column) : 
##  80 1 
## Number of training samples: 
##  80 
## Shape of X_test (row, column): 
##  20 2 
## Shape of y_test (row, column) : 
##  20 1 
## Number of testing samples: 
##  20

Convert the inputs and outputs to matrices, and reshape X and y by taking their transposes so that each example is a column. This will make the matrix calculations slightly less verbose.
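Both preprocessing steps can be sketched as follows; the `train` data frame here is a small hypothetical stand-in for the training split, and `scale()` produces the z-scored feature values seen in the later printouts:

```r
# Hypothetical stand-in for the training data frame loaded earlier.
train <- data.frame(Sepal.Length = c(5.7, 4.5, 5.0),
                    Sepal.Width  = c(2.6, 2.3, 2.3),
                    Species      = c(1L, 0L, 1L))

# Standardize the features, then transpose so each example is a column:
# X becomes (n_features x m), y becomes (1 x m).
X <- t(scale(train[, c("Sepal.Length", "Sepal.Width")]))
y <- t(as.matrix(train[, "Species", drop = FALSE]))

dim(X)  # 2 3
dim(y)  # 1 3
```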

0.4 Build the Neural Net.

  • Initialize the parameters for a two-layer network and for an L-layer neural network.
  • Implement the forward propagation module.
    • Complete the LINEAR part of a layer’s forward propagation step (resulting in Z[l]).
    • Apply the ACTIVATION function (relu/sigmoid).
    • Combine the previous two steps into a new [LINEAR->ACTIVATION] forward function.
    • Stack the [LINEAR->RELU] forward function L-1 times (for layers 1 through L-1) and add a [LINEAR->SIGMOID] at the end (for the final layer L). This gives a new L_model_forward function.
  • Compute the loss.
  • Implement the backward propagation module.
    • Complete the LINEAR part of a layer’s backward propagation step.
    • Compute the gradient of the ACTIVATION function (relu_backward/sigmoid_backward).
    • Combine the previous two steps into a new [LINEAR->ACTIVATION] backward function.
    • Stack [LINEAR->RELU] backward L-1 times and add [LINEAR->SIGMOID] backward in a new L_model_backward function.
  • Finally, update the parameters.
Figure 1: Deep Neural Net Architecture.


0.4.2 Forward Propagation.

We'll use a sample architecture of 2 -> 3 -> 4 -> 1 to test our functions, where 2 is the number of input neurons and 1 is the number of output neurons.
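Parameter initialization for this architecture can be sketched as below. `init_params` is a hypothetical helper; the uniform W draws and zero b columns mirror the parameter values printed further down.

```r
# Initialize W[l] as an (n_l x n_{l-1}) matrix and b[l] as an (n_l x 1)
# column of zeros, for each layer l.
init_params <- function(layer_dims) {
  params <- list()
  for (l in 2:length(layer_dims)) {
    n_l    <- layer_dims[l]
    n_prev <- layer_dims[l - 1]
    params[[paste0("W", l - 1)]] <- matrix(runif(n_l * n_prev), nrow = n_l)
    params[[paste0("b", l - 1)]] <- matrix(0, nrow = n_l, ncol = 1)
  }
  params
}

params <- init_params(c(2, 3, 4, 1))
dim(params$W1)  # 3 2
dim(params$W3)  # 1 4
```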

We will write three functions:

  • LINEAR
  • LINEAR -> ACTIVATION where ACTIVATION will be either ReLU or Sigmoid.
  • [LINEAR -> RELU] × (L-1) -> LINEAR -> SIGMOID (whole model)

Define the sigmoid function.
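A minimal sketch, returning the activation together with Z as the activation cache (matching the cache shapes printed later):

```r
# sigmoid(Z) = 1 / (1 + exp(-Z)); Z is kept as the activation cache.
sigmoid <- function(Z) {
  A <- 1 / (1 + exp(-Z))
  list(A = A, activation_cache = Z)
}

sigmoid(matrix(0, 1, 1))$A  # 0.5
```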

Define the relu function.
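A matching sketch for the ReLU; `pmax()` applies the elementwise maximum directly to matrices:

```r
# relu(Z) = max(Z, 0) elementwise; Z is kept as the activation cache.
relu <- function(Z) {
  list(A = pmax(Z, 0), activation_cache = Z)
}

relu(matrix(c(-1, 2), 1, 2))$A  # 0 2
```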

The equation for forward propagation is:

(4) Z[l] = W[l] A[l-1] + b[l]

where A[0] = X.
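One possible implementation of this LINEAR step, producing the Z and linear_cache shapes printed below; adding `as.vector(b)` relies on R's column-major recycling to broadcast b across the m columns:

```r
# Z[l] = W[l] A[l-1] + b[l]; also cache A, W, b for backpropagation.
linear_forward <- function(A, W, b) {
  Z <- W %*% A + as.vector(b)
  list(Z = Z, linear_cache = list(A = A, W = W, b = b))
}

A0 <- matrix(rnorm(2 * 80), nrow = 2)  # A[0] = X: 2 features, 80 examples
W1 <- matrix(runif(3 * 2), nrow = 3)
b1 <- matrix(0, nrow = 3, ncol = 1)
dim(linear_forward(A0, W1, b1)$Z)      # 3 80
```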

## Shape of Z1: 
##  3 80 
## 
##  linear_cache:
## $A
## [1]  2 80
## 
## $W
## [1] 3 2
## 
## $b
## [1] 3 1
## Shape of Z2: 
##  4 80 
## 
##  linear_cache:
## $A
## [1]  3 80
## 
## $W
## [1] 4 3
## 
## $b
## [1] 4 1
## Shape of Z3: 
##  1 80 
## 
##  linear_cache:
## $A
## [1]  4 80
## 
## $W
## [1] 1 4
## 
## $b
## [1] 1 1

Now we will implement the forward propagation of the LINEAR->ACTIVATION layer.
The mathematical relation is: A[l] = g(Z[l]) = g(W[l] A[l-1] + b[l]), where the activation "g" can be sigmoid() or relu().
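A sketch combining the two steps; the returned cache mirrors the linear_cache / activation_cache structure in the printout below:

```r
# LINEAR -> ACTIVATION forward: A[l] = g(W[l] A[l-1] + b[l]).
linear_activation_forward <- function(A_prev, W, b,
                                      activation = c("relu", "sigmoid")) {
  activation <- match.arg(activation)
  Z <- W %*% A_prev + as.vector(b)
  A <- if (activation == "relu") pmax(Z, 0) else 1 / (1 + exp(-Z))
  list(A = A,
       cache = list(linear_cache = list(A = A_prev, W = W, b = b),
                    activation_cache = Z))
}

out <- linear_activation_forward(matrix(rnorm(160), 2, 80),
                                 matrix(runif(6), 3, 2),
                                 matrix(0, 3, 1), "relu")
dim(out$A)  # 3 80
```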

## Shape of A1: 
##  3 80 
## 
##  Linear Cache:
## $A
## [1]  2 80
## 
## $W
## [1] 3 2
## 
## $b
## [1] 3 1
## 
## Activation Cache:
## [1]  3 80
## Shape of A2: 
##  4 80 
## 
##  Linear Cache:
## $A
## [1]  3 80
## 
## $W
## [1] 4 3
## 
## $b
## [1] 4 1
## 
## Activation Cache:
## [1]  4 80
## Shape of A3: 
##  1 80 
## 
##  Linear Cache:
## $A
## [1]  4 80
## 
## $W
## [1] 1 4
## 
## $b
## [1] 1 1
## 
## Activation Cache:
## [1]  1 80
## $linear_cache
## $linear_cache$A
##                    [,1]      [,2]       [,3]       [,4]       [,5]      [,6]
## Sepal.Length  0.2982138 -1.548788 -0.7792037 -0.6252869  0.1442970  0.144297
## Sepal.Width  -1.0569860 -1.678742 -1.6787425  1.4300399 -0.8497339 -1.264238
##                    [,7]       [,8]       [,9]      [,10]      [,11]      [,12]
## Sepal.Length -0.7792037 -0.1635366 -0.7792037 -0.1635366  1.5295480  1.0677977
## Sepal.Width   1.0155356 -0.2279774 -0.2279774  0.6010313 -0.6424817 -0.4352295
##                   [,13]      [,14]      [,15]      [,16]      [,17]     [,18]
## Sepal.Length  0.4521305 -0.7792037  1.6834648  0.7599641  0.9138809  1.221714
## Sepal.Width  -0.8497339  0.8082834 -0.2279774 -0.4352295 -0.6424817 -1.678742
##                   [,19]      [,20]      [,21]        [,22]       [,23]
## Sepal.Length -0.6252869  0.2982138  0.9138809 -0.009619799 -1.08703727
## Sepal.Width   1.4300399 -0.4352295 -0.6424817 -1.471490339 -0.02072522
##                   [,24]      [,25]      [,26]      [,27]      [,28]      [,29]
## Sepal.Length -0.4713701  0.7599641  1.8373816  0.7599641 -1.7027044 -1.0870373
## Sepal.Width   0.8082834 -0.8497339 -0.2279774 -1.8859947  0.1865269  0.6010313
##                     [,30]       [,31]      [,32]      [,33]      [,34]
## Sepal.Length -0.009619799 -1.39487084 -0.7792037 -0.1635366  1.9912984
## Sepal.Width   2.259048549 -0.02072522 -2.3004990  1.6372921 -0.6424817
##                   [,35]      [,36]      [,37]      [,38]     [,39]      [,40]
## Sepal.Length -1.0870373 -1.7027044  0.2982138 -0.4713701 0.4521305 -0.9331205
## Sepal.Width   0.6010313 -0.2279774 -0.2279774  0.6010313 1.8445442  1.0155356
##                   [,41]      [,42]     [,43]      [,44]      [,45]      [,46]
## Sepal.Length -0.7792037 -0.9331205 2.2991319 -0.6252869 -1.3948708  0.1442970
## Sepal.Width   0.1865269 -0.2279774 0.1865269 -1.2642382  0.6010313 -0.2279774
##                   [,47]       [,48]     [,49]      [,50]      [,51]      [,52]
## Sepal.Length  1.3756312  1.83738159  1.067798 -0.7792037 -0.1635366 -1.0870373
## Sepal.Width  -0.4352295 -0.02072522 -1.885995  0.8082834  1.2227877 -0.2279774
##                   [,53]      [,54]      [,55]      [,56]      [,57]
## Sepal.Length -0.7792037 -0.7792037 -0.6252869  0.9138809  0.4521305
## Sepal.Width   0.6010313  0.3937791  0.8082834 -0.4352295 -0.8497339
##                     [,58]       [,59]      [,60]     [,61]      [,62]
## Sepal.Length -0.009619799 -0.93312049 -1.2409541 1.2217145 -1.3948708
## Sepal.Width   0.808283426 -0.02072522  0.1865269 0.3937791  0.1865269
##                    [,63]      [,64]      [,65]      [,66]        [,67]
## Sepal.Length  1.83738159  0.1442970  0.6060473  0.1442970 -0.009619799
## Sepal.Width  -0.02072522 -0.4352295 -0.2279774 -0.2279774 -1.264238179
##                   [,68]     [,69]      [,70]     [,71]      [,72]      [,73]
## Sepal.Length -0.4713701 0.2982138  0.4521305 0.6060473 -0.6252869 -1.0870373
## Sepal.Width   2.0517964 2.6735529 -1.0569860 0.1865269  0.8082834 -0.2279774
##                   [,74]       [,75]      [,76]      [,77]     [,78]      [,79]
## Sepal.Length -0.4713701  2.14521515  1.6834648 -0.6252869 0.2982138 -0.6252869
## Sepal.Width  -0.8497339 -0.02072522 -0.4352295  0.6010313 1.4300399  1.2227877
##                   [,80]
## Sepal.Length -0.4713701
## Sepal.Width   1.0155356
## 
## $linear_cache$W
##            [,1]      [,2]
## [1,] 0.03907962 0.5702809
## [2,] 0.29190193 0.8887148
## [3,] 0.61510693 0.9397041
## 
## $linear_cache$b
##      [,1]
## [1,]    0
## [2,]    0
## [3,]    0
## 
## 
## $activation_cache
##      [,1] [,2] [,3]      [,4] [,5] [,6]      [,7] [,8] [,9]     [,10]     [,11]
## [1,]    0    0    0 0.7910885    0    0 0.5486896    0    0 0.3363657 0.0000000
## [2,]    0    0    0 1.0883752    0    0 0.6750705    0    0 0.4864088 0.0000000
## [3,]    0    0    0 0.9591960    0    0 0.4750093    0    0 0.4641991 0.3370929
##          [,12] [,13]     [,14]     [,15]      [,16] [,17] [,18]     [,19] [,20]
## [1,] 0.0000000     0 0.4304976 0.0000000 0.00000000     0     0 0.7910885     0
## [2,] 0.0000000     0 0.4908824 0.2887997 0.00000000     0     0 1.0883752     0
## [3,] 0.2478228     0 0.2802536 0.8212796 0.05847222     0     0 0.9591960     0
##      [,21] [,22] [,23]     [,24] [,25]     [,26] [,27]      [,28]     [,29]
## [1,]     0     0     0 0.4425276     0 0.0000000     0 0.03983171 0.3002757
## [2,]     0     0     0 0.5807396     0 0.3337284     0 0.00000000 0.2168371
## [3,]     0     0     0 0.4696042     0 0.9159549     0 0.00000000 0.0000000
##         [,30] [,31] [,32]     [,33]      [,34]     [,35] [,36] [,37]     [,38]
## [1,] 1.287916     0     0 0.9273255 0.00000000 0.3002757     0     0 0.3243357
## [2,] 2.004842     0     0 1.4073491 0.01028082 0.2168371     0     0 0.3965515
## [3,] 2.116920     0     0 1.4379776 0.62111875 0.0000000     0     0 0.2748485
##         [,39]     [,40]      [,41] [,42]     [,43] [,44]     [,45] [,46]
## [1,] 1.069577 0.5426746 0.07592177     0 0.1962220     0 0.2882456     0
## [2,] 1.771252 0.6301419 0.00000000     0 0.8368903     0 0.1269799     0
## [3,] 2.011434 0.3803341 0.00000000     0 1.5894921     0 0.0000000     0
##           [,47]      [,48] [,49]     [,50]     [,51] [,52]      [,53]     [,54]
## [1,] 0.00000000 0.05998498     0 0.4304976 0.6909416     0 0.31230567 0.1941137
## [2,] 0.01475447 0.51791642     0 0.4908824 1.0389730     0 0.30669433 0.1225063
## [3,] 0.43717333 1.11071058     0 0.2802536 1.0484662     0 0.08549794 0.0000000
##          [,55]     [,56] [,57]     [,58] [,59]      [,60]     [,61]      [,62]
## [1,] 0.4365126 0.0000000     0 0.4605727     0 0.05787674 0.2723088 0.05186173
## [2,] 0.5358110 0.0000000     0 0.7155254     0 0.00000000 0.7065781 0.00000000
## [3,] 0.3749289 0.1531475     0 0.7536300     0 0.00000000 1.1215209 0.00000000
##           [,63] [,64]     [,65] [,66] [,67]    [,68]    [,69] [,70]     [,71]
## [1,] 0.05998498     0 0.0000000     0     0 1.151679 1.536330     0 0.1300569
## [2,] 0.51791642     0 0.0000000     0     0 1.685868 2.463075     0 0.3426756
## [3,] 1.11071058     0 0.1585526     0     0 1.638138 2.695782     0 0.5480640
##          [,72] [,73] [,74]     [,75]     [,76]     [,77]     [,78]     [,79]
## [1,] 0.4365126     0     0 0.0720150 0.0000000 0.3183207 0.8271785 0.6728965
## [2,] 0.5358110     0     0 0.6077736 0.1046117 0.3516229 1.3579468 0.9041871
## [3,] 0.3749289     0     0 1.3000611 0.6265239 0.1801732 1.5272477 0.7644403
##          [,80]
## [1,] 0.5607196
## [2,] 0.7649277
## [3,] 0.6643599
## 
## $linear_cache
## $linear_cache$A
##      [,1] [,2] [,3]      [,4] [,5] [,6]      [,7] [,8] [,9]     [,10]     [,11]
## [1,]    0    0    0 0.7910885    0    0 0.5486896    0    0 0.3363657 0.0000000
## [2,]    0    0    0 1.0883752    0    0 0.6750705    0    0 0.4864088 0.0000000
## [3,]    0    0    0 0.9591960    0    0 0.4750093    0    0 0.4641991 0.3370929
##          [,12] [,13]     [,14]     [,15]      [,16] [,17] [,18]     [,19] [,20]
## [1,] 0.0000000     0 0.4304976 0.0000000 0.00000000     0     0 0.7910885     0
## [2,] 0.0000000     0 0.4908824 0.2887997 0.00000000     0     0 1.0883752     0
## [3,] 0.2478228     0 0.2802536 0.8212796 0.05847222     0     0 0.9591960     0
##      [,21] [,22] [,23]     [,24] [,25]     [,26] [,27]      [,28]     [,29]
## [1,]     0     0     0 0.4425276     0 0.0000000     0 0.03983171 0.3002757
## [2,]     0     0     0 0.5807396     0 0.3337284     0 0.00000000 0.2168371
## [3,]     0     0     0 0.4696042     0 0.9159549     0 0.00000000 0.0000000
##         [,30] [,31] [,32]     [,33]      [,34]     [,35] [,36] [,37]     [,38]
## [1,] 1.287916     0     0 0.9273255 0.00000000 0.3002757     0     0 0.3243357
## [2,] 2.004842     0     0 1.4073491 0.01028082 0.2168371     0     0 0.3965515
## [3,] 2.116920     0     0 1.4379776 0.62111875 0.0000000     0     0 0.2748485
##         [,39]     [,40]      [,41] [,42]     [,43] [,44]     [,45] [,46]
## [1,] 1.069577 0.5426746 0.07592177     0 0.1962220     0 0.2882456     0
## [2,] 1.771252 0.6301419 0.00000000     0 0.8368903     0 0.1269799     0
## [3,] 2.011434 0.3803341 0.00000000     0 1.5894921     0 0.0000000     0
##           [,47]      [,48] [,49]     [,50]     [,51] [,52]      [,53]     [,54]
## [1,] 0.00000000 0.05998498     0 0.4304976 0.6909416     0 0.31230567 0.1941137
## [2,] 0.01475447 0.51791642     0 0.4908824 1.0389730     0 0.30669433 0.1225063
## [3,] 0.43717333 1.11071058     0 0.2802536 1.0484662     0 0.08549794 0.0000000
##          [,55]     [,56] [,57]     [,58] [,59]      [,60]     [,61]      [,62]
## [1,] 0.4365126 0.0000000     0 0.4605727     0 0.05787674 0.2723088 0.05186173
## [2,] 0.5358110 0.0000000     0 0.7155254     0 0.00000000 0.7065781 0.00000000
## [3,] 0.3749289 0.1531475     0 0.7536300     0 0.00000000 1.1215209 0.00000000
##           [,63] [,64]     [,65] [,66] [,67]    [,68]    [,69] [,70]     [,71]
## [1,] 0.05998498     0 0.0000000     0     0 1.151679 1.536330     0 0.1300569
## [2,] 0.51791642     0 0.0000000     0     0 1.685868 2.463075     0 0.3426756
## [3,] 1.11071058     0 0.1585526     0     0 1.638138 2.695782     0 0.5480640
##          [,72] [,73] [,74]     [,75]     [,76]     [,77]     [,78]     [,79]
## [1,] 0.4365126     0     0 0.0720150 0.0000000 0.3183207 0.8271785 0.6728965
## [2,] 0.5358110     0     0 0.6077736 0.1046117 0.3516229 1.3579468 0.9041871
## [3,] 0.3749289     0     0 1.3000611 0.6265239 0.1801732 1.5272477 0.7644403
##          [,80]
## [1,] 0.5607196
## [2,] 0.7649277
## [3,] 0.6643599
## 
## $linear_cache$W
##            [,1]      [,2]      [,3]
## [1,] 0.08013768 0.0625536 0.1121582
## [2,] 0.26414429 0.1112459 0.6665622
## [3,] 0.71345435 0.3634207 0.8498899
## [4,] 0.07357475 0.0597748 0.2804513
## 
## $linear_cache$b
##      [,1]
## [1,]    0
## [2,]    0
## [3,]    0
## [4,]    0
## 
## 
## $activation_cache
##      [,1] [,2] [,3]      [,4] [,5] [,6]      [,7] [,8] [,9]     [,10]
## [1,]    0    0    0 0.2390594    0    0 0.1394750    0    0 0.1094459
## [2,]    0    0    0 0.9694026    0    0 0.5366553    0    0 0.4523776
## [3,]    0    0    0 1.7751546    0    0 1.0405052    0    0 0.8112707
## [4,]    0    0    0 0.3922693    0    0 0.2139389    0    0 0.1840082
##           [,11]      [,12] [,13]      [,14]     [,15]       [,16] [,17] [,18]
## [1,] 0.03780772 0.02779534     0 0.09663827 0.1101787 0.006558136     0     0
## [2,] 0.22469340 0.16518930     0 0.35512861 0.5795617 0.038975370     0     0
## [3,] 0.28649186 0.21062207     0 0.72372196 0.8029530 0.049694945     0     0
## [4,] 0.09453813 0.06950221     0 0.13961364 0.2475918 0.016398606     0     0
##          [,19] [,20] [,21] [,22] [,23]     [,24] [,25]     [,26] [,27]
## [1,] 0.2390594     0     0     0     0 0.1244604     0 0.1236077     0
## [2,] 0.9694026     0     0     0     0 0.4945164     0 0.6476668     0
## [3,] 1.7751546     0     0     0     0 0.9258879     0 0.8997446     0
## [4,] 0.3922693     0     0     0     0 0.1989735     0 0.2768292     0
##            [,28]      [,29]     [,30] [,31] [,32]     [,33]      [,34]
## [1,] 0.003192021 0.03762734 0.4660505     0     0 0.3236294 0.07030663
## [2,] 0.010521320 0.10343833 1.9742850     0     0 1.3600110 0.41515799
## [3,] 0.028418109 0.29303607 3.4466194     0     0 2.3951868 0.53161881
## [4,] 0.002930608 0.03505410 0.8082900     0     0 0.5556344 0.17480807
##           [,35] [,36] [,37]      [,38]     [,39]     [,40]       [,41] [,42]
## [1,] 0.03762734     0     0 0.08162374 0.4221104 0.1255639 0.006084195     0
## [2,] 0.10343833     0     0 0.31298976 1.8203134 0.4669614 0.020054302     0
## [3,] 0.29303607     0     0 0.60910470 3.1163019 0.9394222 0.054166717     0
## [4,] 0.03505410     0     0 0.12464831 0.7486794 0.1842589 0.005585925     0
##          [,43] [,44]      [,45] [,46]     [,47]     [,48] [,49]      [,50]
## [1,] 0.2463498     0 0.03104239     0 0.0499555 0.1617798     0 0.09663827
## [2,] 1.2044269     0 0.09026443     0 0.2930446 0.8138185     0 0.35512861
## [3,] 1.7950320     0 0.25179723     0 0.3769113 1.1749998     0 0.72372196
## [4,] 0.5102370     0 0.02879780     0 0.1234878 0.3468719     0 0.13961364
##          [,51] [,52]      [,53]      [,54]     [,55]      [,56] [,57]     [,58]
## [1,] 0.2379560     0 0.05380158 0.02321903 0.1105494 0.01717674     0 0.1661937
## [2,] 0.9969576     0 0.17360193 0.06490235 0.4248225 0.10208234     0 0.7035982
## [3,] 1.7616203     0 0.40693874 0.18301259 0.8248049 0.13015851     0 1.2291369
## [4,] 0.4069839     0 0.06528841 0.02160466 0.1692936 0.04295041     0 0.2880134
##      [,59]       [,60]     [,61]       [,62]     [,63] [,64]      [,65] [,66]
## [1,]     0 0.004638108 0.1918089 0.004156079 0.1617798     0 0.01778297     0
## [2,]     0 0.015287811 0.8980962 0.013698981 0.8138185     0 0.10568520     0
## [3,]     0 0.041292413 1.4042343 0.037000978 1.1749998     0 0.13475229     0
## [4,]     0 0.004258267 0.3768026 0.003815714 0.3468719     0 0.04446629     0
##      [,67]     [,68]     [,69] [,70]     [,71]     [,72] [,73] [,74]     [,75]
## [1,]     0 0.3814806 0.5795461     0 0.0933279 0.1105494     0     0 0.1896020
## [2,]     0 1.5836765 2.4767262     0 0.4377938 0.4248225     0     0 0.9532063
## [3,]     0 2.8265872 4.2823518     0 0.6831191 0.8248049     0     0 1.3771657
## [4,]     0 0.6449249 1.0163004     0 0.1837575 0.1692936     0     0 0.4062318
##          [,76]      [,77]     [,78]     [,79]     [,80]
## [1,] 0.0768136 0.06771266 0.3225259 0.1962227 0.1672971
## [2,] 0.4292548 0.24329585 1.3875661 0.7878759 0.6760431
## [3,] 0.5704944 0.50802172 2.3816525 1.4583714 1.2426711
## [4,] 0.1819626 0.09496836 0.5703490 0.3179440 0.2732988
## 
## $linear_cache
## $linear_cache$A
##      [,1] [,2] [,3]      [,4] [,5] [,6]      [,7] [,8] [,9]     [,10]
## [1,]    0    0    0 0.2390594    0    0 0.1394750    0    0 0.1094459
## [2,]    0    0    0 0.9694026    0    0 0.5366553    0    0 0.4523776
## [3,]    0    0    0 1.7751546    0    0 1.0405052    0    0 0.8112707
## [4,]    0    0    0 0.3922693    0    0 0.2139389    0    0 0.1840082
##           [,11]      [,12] [,13]      [,14]     [,15]       [,16] [,17] [,18]
## [1,] 0.03780772 0.02779534     0 0.09663827 0.1101787 0.006558136     0     0
## [2,] 0.22469340 0.16518930     0 0.35512861 0.5795617 0.038975370     0     0
## [3,] 0.28649186 0.21062207     0 0.72372196 0.8029530 0.049694945     0     0
## [4,] 0.09453813 0.06950221     0 0.13961364 0.2475918 0.016398606     0     0
##          [,19] [,20] [,21] [,22] [,23]     [,24] [,25]     [,26] [,27]
## [1,] 0.2390594     0     0     0     0 0.1244604     0 0.1236077     0
## [2,] 0.9694026     0     0     0     0 0.4945164     0 0.6476668     0
## [3,] 1.7751546     0     0     0     0 0.9258879     0 0.8997446     0
## [4,] 0.3922693     0     0     0     0 0.1989735     0 0.2768292     0
##            [,28]      [,29]     [,30] [,31] [,32]     [,33]      [,34]
## [1,] 0.003192021 0.03762734 0.4660505     0     0 0.3236294 0.07030663
## [2,] 0.010521320 0.10343833 1.9742850     0     0 1.3600110 0.41515799
## [3,] 0.028418109 0.29303607 3.4466194     0     0 2.3951868 0.53161881
## [4,] 0.002930608 0.03505410 0.8082900     0     0 0.5556344 0.17480807
##           [,35] [,36] [,37]      [,38]     [,39]     [,40]       [,41] [,42]
## [1,] 0.03762734     0     0 0.08162374 0.4221104 0.1255639 0.006084195     0
## [2,] 0.10343833     0     0 0.31298976 1.8203134 0.4669614 0.020054302     0
## [3,] 0.29303607     0     0 0.60910470 3.1163019 0.9394222 0.054166717     0
## [4,] 0.03505410     0     0 0.12464831 0.7486794 0.1842589 0.005585925     0
##          [,43] [,44]      [,45] [,46]     [,47]     [,48] [,49]      [,50]
## [1,] 0.2463498     0 0.03104239     0 0.0499555 0.1617798     0 0.09663827
## [2,] 1.2044269     0 0.09026443     0 0.2930446 0.8138185     0 0.35512861
## [3,] 1.7950320     0 0.25179723     0 0.3769113 1.1749998     0 0.72372196
## [4,] 0.5102370     0 0.02879780     0 0.1234878 0.3468719     0 0.13961364
##          [,51] [,52]      [,53]      [,54]     [,55]      [,56] [,57]     [,58]
## [1,] 0.2379560     0 0.05380158 0.02321903 0.1105494 0.01717674     0 0.1661937
## [2,] 0.9969576     0 0.17360193 0.06490235 0.4248225 0.10208234     0 0.7035982
## [3,] 1.7616203     0 0.40693874 0.18301259 0.8248049 0.13015851     0 1.2291369
## [4,] 0.4069839     0 0.06528841 0.02160466 0.1692936 0.04295041     0 0.2880134
##      [,59]       [,60]     [,61]       [,62]     [,63] [,64]      [,65] [,66]
## [1,]     0 0.004638108 0.1918089 0.004156079 0.1617798     0 0.01778297     0
## [2,]     0 0.015287811 0.8980962 0.013698981 0.8138185     0 0.10568520     0
## [3,]     0 0.041292413 1.4042343 0.037000978 1.1749998     0 0.13475229     0
## [4,]     0 0.004258267 0.3768026 0.003815714 0.3468719     0 0.04446629     0
##      [,67]     [,68]     [,69] [,70]     [,71]     [,72] [,73] [,74]     [,75]
## [1,]     0 0.3814806 0.5795461     0 0.0933279 0.1105494     0     0 0.1896020
## [2,]     0 1.5836765 2.4767262     0 0.4377938 0.4248225     0     0 0.9532063
## [3,]     0 2.8265872 4.2823518     0 0.6831191 0.8248049     0     0 1.3771657
## [4,]     0 0.6449249 1.0163004     0 0.1837575 0.1692936     0     0 0.4062318
##          [,76]      [,77]     [,78]     [,79]     [,80]
## [1,] 0.0768136 0.06771266 0.3225259 0.1962227 0.1672971
## [2,] 0.4292548 0.24329585 1.3875661 0.7878759 0.6760431
## [3,] 0.5704944 0.50802172 2.3816525 1.4583714 1.2426711
## [4,] 0.1819626 0.09496836 0.5703490 0.3179440 0.2732988
## 
## $linear_cache$W
##           [,1]     [,2]      [,3]      [,4]
## [1,] 0.2740802 0.632204 0.7860847 0.1559307
## 
## $linear_cache$b
##      [,1]
## [1,]    0
## 
## 
## $activation_cache
##      [,1] [,2] [,3]      [,4] [,5] [,6]      [,7] [,8] [,9]     [,10]     [,11]
## [1,]  0.5  0.5  0.5 0.8942559  0.5  0.5 0.7736063  0.5  0.5 0.7275865 0.5968513
##         [,12] [,13]     [,14]     [,15]     [,16] [,17] [,18]     [,19] [,20]
## [1,] 0.571618   0.5 0.6988183 0.7439108 0.5170083   0.5   0.5 0.8942559   0.5
##      [,21] [,22] [,23]     [,24] [,25]    [,26] [,27]   [,28]     [,29]
## [1,]   0.5   0.5   0.5 0.7513079   0.5 0.767416   0.5 0.50758 0.5772573
##          [,30] [,31] [,32]     [,33]     [,34]     [,35] [,36] [,37]     [,38]
## [1,] 0.9853889   0.5   0.5 0.9487274 0.6741223 0.5772573   0.5   0.5 0.6722743
##          [,39]    [,40]     [,41] [,42]     [,43] [,44]     [,45] [,46]
## [1,] 0.9788123 0.749665 0.5144451   0.5 0.9104851   0.5 0.5666014   0.5
##          [,47]     [,48] [,49]     [,50]     [,51] [,52]     [,53]     [,54]
## [1,] 0.6258597 0.8229666   0.5 0.6988183 0.8950827   0.5 0.6117244 0.5485039
##          [,55]     [,56] [,57]     [,58] [,59]     [,60]     [,61]     [,62]
## [1,] 0.7258441 0.5444468   0.5 0.8177957   0.5 0.5110131 0.8560612 0.5098689
##          [,63] [,64]     [,65] [,66] [,67]     [,68]     [,69] [,70]     [,71]
## [1,] 0.8229666   0.5 0.5460067   0.5   0.5 0.9685763 0.9947768   0.5 0.7043288
##          [,72] [,73] [,74]     [,75]     [,76]     [,77]     [,78]     [,79]
## [1,] 0.7258441   0.5   0.5 0.8582098 0.6833608 0.6425681 0.9491525 0.8516827
##          [,80]
## [1,] 0.8164848
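The [LINEAR->RELU] × (L-1) -> LINEAR -> SIGMOID pass that produced the output above can be sketched as follows (caches are omitted here for brevity, and parameter names W1, b1, … are assumed):

```r
# Whole-model forward pass: ReLU for layers 1..L-1, sigmoid for layer L.
L_model_forward <- function(X, params) {
  L <- length(params) / 2                      # layers with parameters
  A <- X
  for (l in seq_len(L - 1)) {                  # hidden layers: ReLU
    Z <- params[[paste0("W", l)]] %*% A + as.vector(params[[paste0("b", l)]])
    A <- pmax(Z, 0)
  }
  ZL <- params[[paste0("W", L)]] %*% A + as.vector(params[[paste0("b", L)]])
  1 / (1 + exp(-ZL))                           # output layer: sigmoid, AL
}

params <- list(W1 = matrix(runif(6), 3, 2),  b1 = matrix(0, 3, 1),
               W2 = matrix(runif(12), 4, 3), b2 = matrix(0, 4, 1),
               W3 = matrix(runif(4), 1, 4),  b3 = matrix(0, 1, 1))
AL <- L_model_forward(matrix(rnorm(160), 2, 80), params)
dim(AL)  # 1 80
```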

0.4.4 Backward Propagation.

Backpropagation is used to calculate the gradient of the loss function with respect to the parameters.

  • LINEAR backward
  • LINEAR -> ACTIVATION backward where ACTIVATION computes the derivative of either the ReLU or sigmoid activation
  • [LINEAR -> RELU] × (L-1) -> LINEAR -> SIGMOID backward (whole model)
Figure 2: Backpropagation.


Define the relu_backward function.
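A minimal sketch; the cached Z decides where the gradient passes through:

```r
# dZ = dA * g'(Z), where relu's derivative is 1 for Z > 0 and 0 otherwise.
relu_backward <- function(dA, activation_cache) {
  dA * (activation_cache > 0)
}

relu_backward(matrix(c(2, 2), 1, 2), matrix(c(-1, 1), 1, 2))  # 0 2
```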

Define the sigmoid_backward function.
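A matching sketch, recomputing the sigmoid from the cached Z:

```r
# dZ = dA * s * (1 - s), with s = sigmoid(Z) from the cached Z.
sigmoid_backward <- function(dA, activation_cache) {
  s <- 1 / (1 + exp(-activation_cache))
  dA * s * (1 - s)
}

sigmoid_backward(matrix(1), matrix(0))  # 0.25
```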

For layer l, the linear part is: Z[l] = W[l] A[l-1] + b[l] (followed by an activation).

Suppose you have already calculated the derivative dZ[l] = ∂L/∂Z[l]. You want to get (dW[l], db[l], dA[l-1]).


The three outputs (dW[l], db[l], dA[l-1]) are computed using the input dZ[l]. Here are the formulas you need:

(8) dW[l] = ∂L/∂W[l] = (1/m) dZ[l] A[l-1]^T

(9) db[l] = ∂L/∂b[l] = (1/m) Σ_{i=1}^{m} dZ[l](i)

(10) dA[l-1] = ∂L/∂A[l-1] = W[l]^T dZ[l]
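These three formulas translate directly into a linear_backward sketch; the returned names match the dA_prev/dW/db shapes printed below:

```r
linear_backward <- function(dZ, linear_cache) {
  A_prev <- linear_cache$A
  W      <- linear_cache$W
  m      <- ncol(A_prev)                        # number of examples
  dW      <- (1 / m) * dZ %*% t(A_prev)         # equation (8)
  db      <- matrix(rowSums(dZ) / m, ncol = 1)  # equation (9)
  dA_prev <- t(W) %*% dZ                        # equation (10)
  list(dA_prev = dA_prev, dW = dW, db = db)
}

grads <- linear_backward(matrix(c(1, 1), 1, 2),
                         list(A = matrix(c(1, 2, 3, 4), 2, 2),
                              W = matrix(c(1, 0), 1, 2),
                              b = matrix(0, 1, 1)))
grads$dW  # 2 3
```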

## dim dZ3:
## 1 80
## 
## Linear Backward 3:
## 
## $dA_prev
## [1]  4 80
## 
## $dW
## [1] 1 4
## 
## $db
## [1] 1 1
## 
## Linear Backward 2:
## $dA_prev
## [1]  3 80
## 
## $dW
## [1] 4 3
## 
## $db
## [1] 4 1
## 
## Linear Backward 1:
## $dA_prev
## [1]  2 80
## 
## $dW
## [1] 3 2
## 
## $db
## [1] 3 1

In the L_model_backward function, we will iterate through all the layers backward, starting from layer L. At each step, we will use the cached values for layer l to backpropagate through layer l.

Figure 2: L-model backward propagation.

We will now implement backpropagation for the whole [LINEAR->RELU] × (L-1) -> LINEAR -> SIGMOID model.
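A self-contained sketch of the whole-model backward pass, assuming each `caches[[l]]` carries the layer's linear_cache (A, W, b) and activation_cache (Z) from the forward pass, and that the initial dAL is the derivative of the binary cross-entropy loss:

```r
L_model_backward <- function(AL, Y, caches) {
  L <- length(caches)
  m <- ncol(AL)
  grads <- list()
  dA <- -(Y / AL - (1 - Y) / (1 - AL))  # dAL from the cross-entropy loss
  for (l in L:1) {
    lc <- caches[[l]]$linear_cache
    Z  <- caches[[l]]$activation_cache
    dZ <- if (l == L) {                 # sigmoid_backward at the last layer
      s <- 1 / (1 + exp(-Z)); dA * s * (1 - s)
    } else {                            # relu_backward elsewhere
      dA * (Z > 0)
    }
    grads[[paste0("dW", l)]] <- (1 / m) * dZ %*% t(lc$A)
    grads[[paste0("db", l)]] <- matrix(rowSums(dZ) / m, ncol = 1)
    dA <- t(lc$W) %*% dZ                # dA_prev for the next layer down
  }
  grads
}
```

For a sigmoid output with cross-entropy loss, the layer-L dZ simplifies to AL - Y, which is a handy sanity check on the implementation.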

## $dA3
##             [,1]       [,2]        [,3]      [,4]        [,5]        [,6]
## [1,] -0.12881975 0.12881975 -0.12881975 0.5339293 -0.12881975 -0.12881975
## [2,] -0.29714058 0.29714058 -0.29714058 1.2315819 -0.29714058 -0.29714058
## [3,] -0.36946567 0.36946567 -0.36946567 1.5313533 -0.36946567 -0.36946567
## [4,] -0.07328859 0.07328859 -0.07328859 0.3037650 -0.07328859 -0.07328859
##           [,7]        [,8]       [,9]     [,10]       [,11]       [,12]
## [1,] 0.2615377 -0.12881975 0.12881975 0.2209719 -0.10515623 -0.11058810
## [2,] 0.6032728 -0.29714058 0.29714058 0.5097022 -0.24255741 -0.25508676
## [3,] 0.7501116 -0.36946567 0.36946567 0.6337656 -0.30159676 -0.31717580
## [4,] 0.1487949 -0.07328859 0.07328859 0.1257161 -0.05982586 -0.06291618
##            [,13]     [,14]       [,15]       [,16]       [,17]       [,18]
## [1,] -0.12881975 0.2018425 -0.08045412 -0.12405560 -0.12881975 -0.12881975
## [2,] -0.29714058 0.4655776 -0.18557856 -0.28615140 -0.29714058 -0.29714058
## [3,] -0.36946567 0.5789009 -0.23074905 -0.35580169 -0.36946567 -0.36946567
## [4,] -0.07328859 0.1148330 -0.04577224 -0.07057815 -0.07328859 -0.07328859
##          [,19]       [,20]       [,21]       [,22]      [,23]     [,24]
## [1,] 0.5339293 -0.12881975 -0.12881975 -0.12881975 0.12881975 0.2400265
## [2,] 1.2315819 -0.29714058 -0.29714058 -0.29714058 0.29714058 0.5536544
## [3,] 1.5313533 -0.36946567 -0.36946567 -0.36946567 0.36946567 0.6884159
## [4,] 0.3037650 -0.07328859 -0.07328859 -0.07328859 0.07328859 0.1365568
##            [,25]       [,26]       [,27]      [,28]      [,29]     [,30]
## [1,] -0.12881975 -0.07733124 -0.12881975 0.13055836 0.14929783  3.712960
## [2,] -0.29714058 -0.17837521 -0.29714058 0.30115092 0.34437609  8.564455
## [3,] -0.36946567 -0.22179238 -0.36946567 0.37445214 0.42819848 10.649075
## [4,] -0.07328859 -0.04399557 -0.07328859 0.07427773 0.08493905  2.112390
##           [,31]       [,32]     [,33]       [,34]      [,35]      [,36]
## [1,] 0.12881975 -0.12881975 1.0756381 -0.09091711 0.14929783 0.12881975
## [2,] 0.29714058 -0.29714058 2.4811082 -0.20971289 0.34437609 0.29714058
## [3,] 0.36946567 -0.36946567 3.0850189 -0.26075777 0.42819848 0.36946567
## [4,] 0.07328859 -0.07328859 0.6119559 -0.05172489 0.08493905 0.07328859
##            [,37]     [,38]    [,39]     [,40]     [,41]      [,42]       [,43]
## [1,] -0.12881975 0.1871260 2.568138 0.2385918 0.1321772 0.12881975 -0.06158663
## [2,] -0.29714058 0.4316321 5.923766 0.5503451 0.3048849 0.29714058 -0.14205808
## [3,] -0.36946567 0.5366929 7.365633 0.6843010 0.3790950 0.36946567 -0.17663553
## [4,] -0.07328859 0.1064604 1.461074 0.1357405 0.0751987 0.07328859 -0.03503808
##            [,44]      [,45]       [,46]       [,47]       [,48]       [,49]
## [1,] -0.12881975 0.14605975 -0.12881975 -0.09942346 -0.07061345 -0.12881975
## [2,] -0.29714058 0.33690702 -0.29714058 -0.22933398 -0.16287968 -0.29714058
## [3,] -0.36946567 0.41891141 -0.36946567 -0.28515469 -0.20252519 -0.36946567
## [4,] -0.07328859 0.08309683 -0.07328859 -0.05656435 -0.04017365 -0.07328859
##          [,50]     [,51]      [,52]      [,53]      [,54]     [,55]       [,56]
## [1,] 0.2018425 0.5379501 0.12881975 0.16094104 0.14089646 0.2197007 -0.11696838
## [2,] 0.4655776 1.2408563 0.29714058 0.37123278 0.32499719 0.5067702 -0.26980377
## [3,] 0.5789009 1.5428852 0.36946567 0.46159218 0.40410268 0.6301198 -0.33547498
## [4,] 0.1148330 0.3060525 0.07328859 0.09156315 0.08015932 0.1249929 -0.06654607
##            [,57]     [,58]      [,59]      [,60]       [,61]      [,62]
## [1,] -0.12881975 0.3195830 0.12881975 0.13136251 -0.06699816 0.13109330
## [2,] -0.29714058 0.7371625 0.29714058 0.30300581 -0.15454054 0.30238484
## [3,] -0.36946567 0.9165905 0.36946567 0.37675851 -0.19215626 0.37598640
## [4,] -0.07328859 0.1818183 0.07328859 0.07473523 -0.03811683 0.07458207
##            [,63]       [,64]       [,65]       [,66]       [,67]     [,68]
## [1,] -0.07061345 -0.12881975 -0.11658580 -0.12881975 -0.12881975 1.7396089
## [2,] -0.16287968 -0.29714058 -0.26892129 -0.29714058 -0.29714058 4.0126487
## [3,] -0.20252519 -0.36946567 -0.33437770 -0.36946567 -0.36946567 4.9893419
## [4,] -0.04017365 -0.07328859 -0.06632841 -0.07328859 -0.07328859 0.9897045
##          [,69]       [,70]       [,71]     [,72]      [,73]       [,74]
## [1,] 10.341757 -0.12881975 -0.08615074 0.2197007 0.12881975 -0.12881975
## [2,] 23.854693 -0.29714058 -0.19871860 0.5067702 0.29714058 -0.29714058
## [3,] 29.661012 -0.36946567 -0.24708742 0.6301198 0.36946567 -0.36946567
## [4,]  5.883669 -0.07328859 -0.04901319 0.1249929 0.07328859 -0.07328859
##            [,75]       [,76]      [,77]     [,78]     [,79]     [,80]
## [1,] -0.06677242 -0.08941752 0.17319913 1.0844270 0.3873839 0.3174612
## [2,] -0.15401983 -0.20625387 0.39950776 2.5013808 0.8935547 0.7322683
## [3,] -0.19150881 -0.25645681 0.49674939 3.1102260 1.1110491 0.9105050
## [4,] -0.03798840 -0.05087173 0.09853707 0.6169561 0.2203919 0.1806112
## 
## $dW3
##             [,1]     [,2]     [,3]      [,4]
## Species 0.496205 2.099032 3.669874 0.8591043
## 
## $db3
##          [,1]
## [1,] 1.059136
## 
## $dA2
##      [,1] [,2] [,3]      [,4] [,5] [,6]     [,7] [,8] [,9]     [,10]      [,11]
## [1,]    0    0    0 1.4830033    0    0 0.726428    0    0 0.6137554 -0.2920743
## [2,]    0    0    0 0.7450906    0    0 0.364972    0    0 0.3083630 -0.1467440
## [3,]    0    0    0 2.2674835    0    0 1.110694    0    0 0.9384202 -0.4465760
##           [,12] [,13]     [,14]      [,15]      [,16] [,17] [,18]     [,19]
## [1,] -0.3071615     0 0.5606231 -0.2234635 -0.3445678     0     0 1.4830033
## [2,] -0.1543241     0 0.2816682 -0.1122725 -0.1731178     0     0 0.7450906
## [3,] -0.4696440     0 0.8571818 -0.3416714 -0.5268376     0     0 2.2674835
##      [,20] [,21] [,22] [,23]     [,24] [,25]      [,26] [,27]     [,28]
## [1,]     0     0     0     0 0.6666803     0 -0.2147896     0 0.3626294
## [2,]     0     0     0     0 0.3349535     0 -0.1079146     0 0.1821923
## [3,]     0     0     0     0 1.0193413     0 -0.3284092     0 0.5544534
##          [,29]     [,30] [,31] [,32]    [,33]      [,34]     [,35] [,36] [,37]
## [1,] 0.4146788 10.312847     0     0 2.987614 -0.2525248 0.4146788     0     0
## [2,] 0.2083429  5.181381     0     0 1.501037 -0.1268735 0.2083429     0     0
## [3,] 0.6340359 15.768145     0     0 4.568005 -0.3861055 0.6340359     0     0
##          [,38]     [,39]     [,40]     [,41] [,42]       [,43] [,44]     [,45]
## [1,] 0.5197476  7.133074 0.6626953 0.3671257     0 -0.17105854     0 0.4056850
## [2,] 0.2611316  3.583799 0.3329514 0.1844513     0 -0.08594324     0 0.2038242
## [3,] 0.7946841 10.906333 1.0132484 0.5613281     0 -0.26154522     0 0.6202845
##      [,46]      [,47]       [,48] [,49]     [,50]     [,51] [,52]     [,53]
## [1,]     0 -0.2761514 -0.19613078     0 0.5606231 1.4941710     0 0.4470181
## [2,]     0 -0.1387440 -0.09854003     0 0.2816682 0.7507015     0 0.2245909
## [3,]     0 -0.4222302 -0.29988018     0 0.8571818 2.2845587     0 0.6834821
##          [,54]     [,55]      [,56] [,57]     [,58] [,59]     [,60]       [,61]
## [1,] 0.3913438 0.6102248 -0.3248829     0 0.8876507     0 0.3648630 -0.18608923
## [2,] 0.1966190 0.3065892 -0.1632277     0 0.4459735     0 0.1833145 -0.09349496
## [3,] 0.5983571 0.9330220 -0.4967397     0 1.3572007     0 0.5578684 -0.28452685
##          [,62]       [,63] [,64]      [,65] [,66] [,67]    [,68]    [,69] [,70]
## [1,] 0.3641152 -0.19613078     0 -0.3238203     0     0 4.831811 28.72451     0
## [2,] 0.1829388 -0.09854003     0 -0.1626938     0     0 2.427599 14.43177     0
## [3,] 0.5567252 -0.29988018     0 -0.4951149     0     0 7.387747 43.91923     0
##           [,71]     [,72] [,73] [,74]       [,75]      [,76]     [,77]    [,78]
## [1,] -0.2392860 0.6102248     0     0 -0.18546223 -0.2483596 0.4810653 3.012026
## [2,] -0.1202221 0.3065892     0     0 -0.09317994 -0.1247808 0.2416969 1.513302
## [3,] -0.3658637 0.9330220     0     0 -0.28356817 -0.3797370 0.7355396 4.605329
##          [,79]     [,80]
## [1,] 1.0759695 0.8817573
## [2,] 0.5405886 0.4430125
## [3,] 1.6451366 1.3481899
## 
## $dW2
##           [,1]      [,2]      [,3]
## [1,] 0.3779573 0.5864597 0.6154363
## [2,] 0.8718107 1.3527504 1.4195891
## [3,] 1.0840126 1.6820147 1.7651222
## [4,] 0.2150288 0.3336507 0.3501362
## 
## $db2
##          [,1]
## [1,] 2.120485
## [2,] 2.120485
## [3,] 2.120485
## [4,] 2.120485
## 
## $dA1
##      [,1] [,2] [,3]     [,4] [,5] [,6]      [,7] [,8] [,9]     [,10]      [,11]
## [1,]    0    0    0 1.670193    0    0 0.8181204    0    0 0.6912259 -0.2746920
## [2,]    0    0    0 3.638665    0    0 1.7823481    0    0 1.5058971 -0.4196493
##           [,12] [,13]     [,14]      [,15]      [,16] [,17] [,18]    [,19]
## [1,] -0.2888813     0 0.6313869 -0.2429370 -0.3240614     0     0 1.670193
## [2,] -0.4413264     0 1.3755327 -0.4208483 -0.4950714     0     0 3.638665
##      [,20] [,21] [,22] [,23]     [,24] [,25]      [,26] [,27]      [,28]
## [1,]     0     0     0     0 0.7508311     0 -0.2335073     0 0.01417142
## [2,]     0     0     0     0 1.6357523     0 -0.4045128     0 0.20680063
##           [,29]    [,30] [,31] [,32]    [,33]      [,34]      [,35] [,36] [,37]
## [1,] 0.07702119 11.61457     0     0 3.364722 -0.2745308 0.07702119     0     0
## [2,] 0.42164086 25.30338     0     0 7.330346 -0.4755793 0.42164086     0     0
##          [,38]     [,39]     [,40]      [,41] [,42]      [,43] [,44]     [,45]
## [1,] 0.5853521  8.033437 0.7463431 0.01434713     0 -0.1926502     0 0.0753507
## [2,] 1.2752416 17.501558 1.6259750 0.20936475     0 -0.4197056     0 0.4124960
##      [,46]      [,47]      [,48] [,49]     [,50]    [,51] [,52]     [,53]
## [1,]     0 -0.3002163 -0.2208871     0 0.6313869 1.682771     0 0.5034424
## [2,]     0 -0.5200752 -0.4812223     0 1.3755327 3.666066     0 1.0967941
##           [,54]     [,55]      [,56] [,57]     [,58] [,59]      [,60]
## [1,] 0.07268702 0.6872496 -0.3055480     0 0.9996931     0 0.01425871
## [2,] 0.39791407 1.4972345 -0.4667883     0 2.1779205     0 0.20807439
##           [,61]      [,62]      [,63] [,64]      [,65] [,66] [,67]    [,68]
## [1,] -0.2095781 0.01422949 -0.2208871     0 -0.3045486     0     0  5.44170
## [2,] -0.4565845 0.20764797 -0.4812223     0 -0.4652615     0     0 11.85523
##         [,69] [,70]      [,71]     [,72] [,73] [,74]      [,75]      [,76]
## [1,] 32.35023     0 -0.2694896 0.6872496     0     0 -0.2088719 -0.2700027
## [2,] 70.47785     0 -0.5871071 1.4972345     0     0 -0.4550461 -0.4677350
##          [,77]    [,78]    [,79]     [,80]
## [1,] 0.5417871 3.392214 1.211782 0.9930559
## [2,] 1.1803315 7.390241 2.639976 2.1634607
## 
## $dW1
##      Sepal.Length Sepal.Width
## [1,]  -0.03778349   1.8434129
## [2,]  -0.02028630   0.9275388
## [3,]  -0.05709898   2.8187046
## 
## $db1
##          [,1]
## [1,] 2.624512
## [2,] 2.624512
## [3,] 2.624512
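The gradients above come out of the backward pass. A minimal sketch of the backward step for a single layer (helper names are mine, not necessarily the ones used in this notebook): `dW` and `db` are averaged over the m training examples, and `dA_prev` is handed back to the previous layer. The ReLU backward step zeroes the gradient wherever the pre-activation was non-positive, which is why `dA2` and `dA1` above contain so many zero columns.

```r
# Backward step for the LINEAR part of one layer (a sketch, not the
# notebook's exact implementation).
linear_backward <- function(dZ, A_prev, W) {
  m <- ncol(A_prev)                          # number of examples
  dW <- (dZ %*% t(A_prev)) / m               # gradient w.r.t. weights
  db <- matrix(rowSums(dZ) / m, ncol = 1)    # gradient w.r.t. bias
  dA_prev <- t(W) %*% dZ                     # gradient for previous layer
  list(dA_prev = dA_prev, dW = dW, db = db)
}

# ReLU backward: gradient flows only where the pre-activation Z was positive.
relu_backward <- function(dA, Z) dA * (Z > 0)
```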

0.4.5 Update Parameters.

## $W1
##      Sepal.Length Sepal.Width
## [1,]   0.04285797   0.3859396
## [2,]   0.29393056   0.7959609
## [3,]   0.62081683   0.6578336
## 
## $b1
##            [,1]
## [1,] -0.2624512
## [2,] -0.2624512
## [3,] -0.2624512
## 
## $W2
##            [,1]         [,2]       [,3]
## [1,] 0.04234196  0.003907636 0.05061452
## [2,] 0.17696322 -0.024029181 0.52460332
## [3,] 0.60505309  0.195219223 0.67337768
## [4,] 0.05207187  0.026409725 0.24543763
## 
## $b2
##            [,1]
## [1,] -0.2120485
## [2,] -0.2120485
## [3,] -0.2120485
## [4,] -0.2120485
## 
## $W3
##              [,1]      [,2]      [,3]       [,4]
## Species 0.2244597 0.4223009 0.4190973 0.07002028
## 
## $b3
##            [,1]
## [1,] -0.1059136
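The update applied above can be sketched as plain gradient descent: each parameter moves against its gradient, scaled by the learning rate. The list layout for `parameters`/`grads` is an assumption here, chosen to match the `W1`/`dW1`-style names in the output.

```r
# One gradient-descent update over all L layers (a sketch).
update_parameters <- function(parameters, grads, learning_rate = 0.1) {
  L <- length(parameters) / 2   # one W and one b per layer
  for (l in 1:L) {
    parameters[[paste0("W", l)]] <- parameters[[paste0("W", l)]] -
      learning_rate * grads[[paste0("dW", l)]]
    parameters[[paste0("b", l)]] <- parameters[[paste0("b", l)]] -
      learning_rate * grads[[paste0("db", l)]]
  }
  parameters
}
```

This is consistent with the output above: with the biases initialized to zero, `db1 = 2.624512` updated with a learning rate of 0.1 gives `b1 = -0.2624512` (and likewise `db3 = 1.059136` gives `b3 = -0.1059136`), though the exact learning rate used is inferred, not stated in this chunk.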

0.4.7 Make predictions.

##      [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13] [,14]
## [1,]    1    1    1    1    1    0    0    0    1     1     0     1     1     0
##      [,15] [,16] [,17] [,18] [,19] [,20]
## [1,]     1     1     1     0     1     0
##         [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13]
## Species    0    0    0    0    1    1    1    0    0     0     0     1     0
##         [,14] [,15] [,16] [,17] [,18] [,19] [,20]
## Species     0     1     1     1     1     1     1
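The 0/1 labels in the first matrix above come from thresholding the final sigmoid activation at 0.5. A minimal, self-contained sketch of that step (the full forward pass is the `L_model_forward` function built earlier):

```r
# Sigmoid activation of the final layer.
sigmoid <- function(z) 1 / (1 + exp(-z))

# Convert final-layer activations into hard 0/1 class labels.
predict_labels <- function(AL) ifelse(AL > 0.5, 1, 0)
```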

0.4.8 Calculate accuracy.

##       y_pred
## y_test 0 1
##      0 3 7
##      1 4 6
## We are getting an accuracy of 45 %.
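The accuracy is the share of the confusion matrix's diagonal. Using the predictions and test labels printed in section 0.4.7:

```r
# Predictions and test labels copied from the output above.
y_pred <- c(1,1,1,1,1,0,0,0,1,1,0,1,1,0,1,1,1,0,1,0)
y_test <- c(0,0,0,0,1,1,1,0,0,0,0,1,0,0,1,1,1,1,1,1)

table(y_test, y_pred)                    # the confusion matrix shown above
accuracy <- mean(y_pred == y_test) * 100 # 3 + 6 correct out of 20 -> 45
cat("We are getting an accuracy of", accuracy, "%.\n")
```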